"""
Martini Glass Conjecture Tester (streaming)

This script tests the conjecture:

    The midpoint (MP) between consecutive primes is divisible by
    a twin prime midpoint (MT), unless:
      1. The midpoint itself is a twin prime midpoint, OR
      2. The prime gap Δ is itself a twin midpoint or a multiple of one.

Features:
- Uses gmpy2 for fast prime generation and low memory use
- Streams primes without storing them all
- Logs results to CSV
- Prints progress every N tests
- Prints any counterexamples immediately

Usage:
    python martini_conjecture_tester.py --start 1000000 --finish 10000000 \
        --output results.csv --report-interval 1000

Arguments:
    --start            First number to test from (default=2)
    --finish           Last number to test up to (default=100_000_000)
    --output           CSV file path for results (default=results.csv)
    --report-interval  How many tests between progress reports (default=100000)
"""

import gmpy2
import csv
import time
import os
import argparse

def bosh_test(start: int = 2, finish: int = 1_000_000,
              output_file: str = "results.csv",
              report_interval: int = 100_000) -> dict:
    """Test the Martini Glass conjecture on consecutive primes in a range.

    Walks consecutive prime pairs (p1, p2), beginning with the first prime
    >= ``start`` and continuing while ``p2 <= finish``.  A pair is *tested*
    when its gap is even, is not 2, 4 or 6, and is not divisible by any twin
    prime midpoint <= the gap.  A tested pair is "satisfied" when its midpoint
    ``(p1 + p2) // 2`` is divisible by at least one twin prime midpoint;
    otherwise it is recorded as a "FAIL" and printed immediately.

    Every tested pair is written as one row to ``output_file`` (the file is
    overwritten, and its parent directory is created if missing).  Rows are
    flushed to disk whenever a counterexample is found and at every progress
    report, so an interrupted run keeps its results up to that point.

    Args:
        start: Lower bound; streaming begins at the first prime >= start.
        finish: Upper bound (inclusive) for the larger prime of each pair.
        output_file: Path of the CSV file to write.
        report_interval: Number of tested pairs between progress printouts.

    Returns:
        A dict with keys ``"tested"``, ``"satisfied"``, ``"fails"`` (counts)
        and ``"fail_examples"`` (up to the first 10 counterexamples, each a
        ``(p1, p2, gap, midpoint)`` tuple).
    """
    max_fail_examples = 10

    start_time = time.time()
    tested = 0
    satisfied = 0
    # Only the count and the first few counterexamples are ever reported, so
    # there is no need to retain every failure in memory.
    fail_count = 0
    fail_examples = []

    # Twin prime midpoints discovered so far, in ascending order.  They are
    # appended while walking primes upward, so the list stays sorted.
    twin_midpoints = []
    next_twin_candidate = 3  # next prime not yet examined as a twin's lower half

    def update_twin_midpoints(up_to):
        """Record the midpoint of every twin pair (p, p+2) with p <= up_to."""
        nonlocal next_twin_candidate
        p = next_twin_candidate
        while p <= up_to:
            # p is always prime here (3 initially, then next_prime results),
            # so only p + 2 needs testing.
            if gmpy2.is_prime(p + 2):
                twin_midpoints.append(p + 1)
            p = int(gmpy2.next_prime(p))
        # Resume from the first unprocessed prime so no prime is visited twice.
        next_twin_candidate = p

    def gap_has_twin_divisor(gap):
        """Return True if some known twin midpoint <= gap divides gap."""
        for t in twin_midpoints:
            if t > gap:
                # Sorted list: every remaining midpoint is larger than the gap.
                break
            if gap % t == 0:
                return True
        return False

    os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)

    # Keep the CSV open for the whole run instead of reopening it per row.
    with open(output_file, "w", newline="") as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(["p1", "p2", "gap", "midpoint", "status", "twin_divisors"])

        # Prime streaming
        p1 = int(gmpy2.next_prime(start - 1))   # first prime >= start
        p2 = int(gmpy2.next_prime(p1))

        last_report = 0
        while p2 <= finish:
            gap = p2 - p1
            # Odd gaps and the trivial gaps 2, 4, 6 are never tested.
            if gap % 2 == 0 and gap not in (2, 4, 6):
                m = (p1 + p2) // 2
                update_twin_midpoints(m)

                # Skip if the gap is itself (a multiple of) a twin midpoint.
                if not gap_has_twin_divisor(gap):
                    tested += 1
                    divisors = [t for t in twin_midpoints
                                if t <= m and m % t == 0]
                    if divisors:
                        satisfied += 1
                        status = "satisfied"
                        divisor_str = ";".join(map(str, divisors))
                    else:
                        status = "FAIL"
                        divisor_str = ""
                        fail_count += 1
                        if len(fail_examples) < max_fail_examples:
                            fail_examples.append((p1, p2, gap, m))
                        # Print exceptions immediately
                        print(f"❌ EXCEPTION: p1={p1}, p2={p2}, gap={gap}, midpoint={m}")

                    writer.writerow([p1, p2, gap, m, status, divisor_str])
                    if status == "FAIL":
                        # Make sure a counterexample reaches disk right away.
                        csv_file.flush()

                    # Binned progress report
                    if tested - last_report >= report_interval:
                        csv_file.flush()
                        elapsed = time.time() - start_time
                        print(f"Processed {tested:,} tests "
                              f"(satisfied={satisfied:,}, fails={fail_count}) "
                              f"elapsed={elapsed:.1f}s")
                        last_report = tested

            # Step forward
            p1, p2 = p2, int(gmpy2.next_prime(p2))

    # Final report
    print(f"\n✅ Done. Tested {tested:,} midpoints, satisfied {satisfied:,}, fails={fail_count}")
    if fail_examples:
        print("❌ Counterexamples (first few):")
        for example in fail_examples:
            print(example)
    else:
        print("🎉 No counterexamples in range", start, "to", finish)

    return {"tested": tested, "satisfied": satisfied,
            "fails": fail_count, "fail_examples": fail_examples}


if __name__ == "__main__":
    # Command-line entry point: declare the options from a small table,
    # parse them, and hand the values straight to bosh_test.
    cli = argparse.ArgumentParser(description="Martini Glass Conjecture Tester")

    # (flag, type, default, help text) -- registered in this order.
    option_table = (
        ("--start", int, 2, "Start number (inclusive)"),
        ("--finish", int, 100_000_000, "Finish number (inclusive)"),
        ("--output", str, "results.csv", "CSV output file path"),
        ("--report-interval", int, 100_000, "Progress report interval (in tests)"),
    )
    for flag, value_type, fallback, description in option_table:
        cli.add_argument(flag, type=value_type, default=fallback, help=description)

    options = cli.parse_args()

    bosh_test(
        start=options.start,
        finish=options.finish,
        output_file=options.output,
        report_interval=options.report_interval,
    )
